/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing warp-level matrix multiply-accumulate
   operations.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/gemm/gemm.h"

#include "cutlass/array_planar_complex.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace warp {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename TileIterator_>
class TileIteratorPlanarComplex {
public:
    /// Underlying iterator over real-valued tiles
    using TileIterator = TileIterator_;

    /// Underlying element type
    using Element = typename TileIterator::Element;

    /// Underlying layout type
    using Layout = typename TileIterator::Layout;

    /// TensorRef type for loading element from a tensor
    using TensorRef = typename TileIterator::TensorRef;

    /// Index type
    using Index = typename TensorRef::Index;

    /// Long Index type
    using LongIndex = typename TensorRef::LongIndex;

    /// Coordinate for an element in the tensor
    using TensorCoord = typename TensorRef::TensorCoord;

    /// Planar complex fragment
    using Fragment =
            ArrayPlanarComplex<Element, TileIterator::Fragment::kElements>;

public:
    /// Underlying tile iterator
    TileIterator tile_iterator_;

    /// Offset (in units of bytes) to the imaginary part of the planar complex
    /// matrix
    LongIndex imaginary_offset_;

public:
    /// Default ctor constructs null iterator
    CUTLASS_HOST_DEVICE
    TileIteratorPlanarComplex() : imaginary_offset_(0) {}

    /// Constructor from TensorRef
    CUTLASS_DEVICE
    TileIteratorPlanarComplex(TensorRef const& ref, int lane_id,
                              LongIndex imaginary_offset)
            : tile_iterator_(ref, lane_id),
              imaginary_offset_(
                      (imaginary_offset * sizeof_bits<Element>::value) / 8) {}

    /// Adds a pointer offset to internal pointer(s) to advance through memory
    CUTLASS_DEVICE
    TileIteratorPlanarComplex& add_pointer_offset(LongIndex offset) {
        tile_iterator_.add_pointer_offset(offset);

        return *this;
    }

    /// Advances an iterator along logical dimensions of matrix in units of
    /// whole tiles
    CUTLASS_HOST_DEVICE
    TileIteratorPlanarComplex& add_tile_offset(TensorCoord const& tile_offset) {
        tile_iterator_.add_tile_offset(tile_offset);

        return *this;
    }

    /// Advances the iterator along the advance dimension
    CUTLASS_DEVICE
    TileIteratorPlanarComplex& operator++() {
        ++tile_iterator_;
        return *this;
    }

    //
    // WIP
    //

    /// Advances the iterator along the opposite of the advance dimension
    CUTLASS_HOST_DEVICE
    TileIteratorPlanarComplex& operator--() {
        --tile_iterator_;
        return *this;
    }

    ///< advances in units of whole tiles along the logical coordinate space of
    ///< the tensor
    CUTLASS_DEVICE
    TileIteratorPlanarComplex& operator+=(TensorCoord const& tile_offset) {
        tile_iterator_.add_tile_offset(tile_offset);
        return *this;
    }

    ///< advances in units of whole tiles along the logical coordinate space of
    ///< the tensor
    CUTLASS_DEVICE
    TileIteratorPlanarComplex& operator-=(TensorCoord const& tile_offset) {
        tile_iterator_.add_tile_offset(-tile_offset);
        return *this;
    }

    /// Loads a fragment from memory at the location pointed to by the iterator.
    CUTLASS_HOST_DEVICE
    void load(Fragment& frag) const {
        tile_iterator_.load_with_byte_offset(frag.real, 0);
        tile_iterator_.load_with_byte_offset(frag.imag, imaginary_offset_);
    }

    /// Loads a fragment from memory with additional logical offset
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset in units of bytes
            Index byte_offset) const {
        tile_iterator_.load_with_byte_offset(frag.real, byte_offset);
        tile_iterator_.load_with_byte_offset(frag.imag,
                                             byte_offset + imaginary_offset_);
    }

    /// Loads a fragment from memory with additional logical offset
    CUTLASS_DEVICE
    void load_with_pointer_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a linear offset
            Index pointer_offset) const {
        Index byte_offset = (pointer_offset * sizeof_bits<Element>::value) / 8;

        tile_iterator_.load_with_byte_offset(frag.real, byte_offset);
        tile_iterator_.load_with_byte_offset(frag.imag,
                                             byte_offset + imaginary_offset_);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset) const {
        tile_iterator_.load_with_byte_offset(frag.real, tile_offset, 0);
        tile_iterator_.load_with_byte_offset(frag.imag, tile_offset,
                                             imaginary_offset_);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index pointer_offset) const {
        Index byte_offset = (pointer_offset * sizeof_bits<Element>::value) / 8;

        tile_iterator_.load_with_byte_offset(frag.real, tile_offset,
                                             byte_offset);
        tile_iterator_.load_with_byte_offset(frag.real, tile_offset,
                                             byte_offset + imaginary_offset_);
    }

    /// Loads a fragment from memory with logical offset in units of whole
    /// tiles.
    CUTLASS_DEVICE
    void load_with_byte_offset(
            /// fragment to load from the tensor
            Fragment& frag,
            /// loads a tile with a logical offset in units of whole tiles
            TensorCoord const& tile_offset,
            /// loads a tile with a logical offset AND a pointer offset
            Index byte_offset) const {
        tile_iterator_.load_with_byte_offset(frag.real, tile_offset,
                                             byte_offset);
        tile_iterator_.load_with_byte_offset(frag.imag, tile_offset,
                                             byte_offset + imaginary_offset_);
    }

    /// Notify the iterator which k-group it is currently pointing to.
    ///
    /// This does not advance the iterator. Rather, it overrides its internal
    /// tracking with constant-valued k-group index to enable the compiler to
    /// fold constants and achieve more efficient code.
    ///
    /// This is used by some nontrivial permuted layouts.
    CUTLASS_DEVICE
    void set_kgroup_index(int k_group) {
        tile_iterator_.set_kgroup_index(k_group);
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace warp
}  // namespace gemm
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
